"""
Phase 1: Data Assembly — Sovereign Spreads Panel
==================================================
Merges full_panel.csv with S&P rating data (from safe_assets),
constructs sovereign spread proxies, adds fiscal determinants.

Extends the safe_asset_cliff rating panel from 31 to all countries
with bond yield data. Key innovation: we use govt_bond_10y spread
vs GDP-weighted world average as the primary dependent variable,
allowing us to test demographics on a much wider sample than
existing EMBI-only studies.

Output: sovereign_spreads/data/processed/spread_panel.csv
Tables: summary_statistics.md
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

# ── Paths ──────────────────────────────────────────────────────────────────
# All locations hang off one hard-coded absolute project directory; the sibling
# projects (multilateral, safe_assets, fiscal_dominance) are resolved relative
# to its parent.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/sovereign_spreads")
ROOT_DIR = PROJECT_DIR.parent
MULTILATERAL_DIR = ROOT_DIR / "multilateral"
SAFE_DIR = ROOT_DIR / "safe_assets"
FISCAL_DIR = ROOT_DIR / "fiscal_dominance"
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
TABLES_DIR = PROJECT_DIR / "output" / "tables"

# Side effect at import time: the output directories are created if missing.
for d in [PROCESSED_DIR, TABLES_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# Both directories are pushed to the FRONT of sys.path, so the one inserted last
# (safe_assets/scripts) is searched first, ahead of multilateral/src and of the
# directory this script lives in.
# NOTE(review): the import below presumably resolves to
# safe_assets/scripts/phase1_data_assembly.py — confirm no other module of that
# name (including this file itself) can shadow it.
sys.path.insert(0, str(MULTILATERAL_DIR / "src"))
sys.path.insert(0, str(SAFE_DIR / "scripts"))
# RATING_HISTORY and RATING_SCALE are imported but not referenced in the code
# visible in this file; SAFE_THRESHOLD and build_ratings_panel are used by
# build_extended_ratings_panel().
from phase1_data_assembly import RATING_HISTORY, RATING_SCALE, SAFE_THRESHOLD, build_ratings_panel

# ISO3 codes of the 38 countries flagged by the `oecd` dummy built in main().
# NOTE(review): presumably the current OECD membership list — confirm the
# vintage if membership timing matters for the analysis (the dummy is
# time-invariant).
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]

# ── Extended rating history for major EM economies ────────────────────────
# These countries were never in safe_assets (below AA-) but have S&P ratings.
# Source: S&P Global Ratings historical actions through end-2024.
#
# Format: iso3 -> list of (start_year, end_year, rating_numeric) periods.
# build_extended_ratings_panel() expands every period into one row per year
# with BOTH endpoints included. Consecutive periods here share their boundary
# year (e.g. ..., 2010) / (2010, ...), so that year is generated twice and the
# duplicate is resolved in favour of the higher numeric code.
# rating_numeric is the value compared against SAFE_THRESHOLD (safe_issuer flag)
# and against 12 in main() (investment_grade flag).
#
# NOTE(review): the trailing letter-grade comments are informal and have not
# been checked against RATING_SCALE in this file; only the numeric code is used.
# NOTE(review): several entries carry codes of 17-20 (e.g. CHN, PRT, MYS, GRC);
# confirm against SAFE_THRESHOLD that none of them reaches the safe-issuer
# cutoff, which would contradict the "below AA-" statement above.
RATING_HISTORY_EM = {
    'CHN': [(2001, 2010, 14), (2010, 2017, 20), (2017, 2024, 17)],  # BBB → AA+ → A+
    'BRA': [(1994, 2008, 9), (2008, 2011, 13), (2011, 2015, 13),
            (2015, 2018, 9), (2018, 2024, 10)],  # BB- → BBB → BB
    'MEX': [(1990, 2000, 10), (2000, 2009, 13), (2009, 2024, 13)],  # BBB range
    'IND': [(1990, 2007, 10), (2007, 2024, 13)],  # BBB-
    'RUS': [(1996, 1998, 8), (1998, 2001, 3), (2001, 2005, 10),
            (2005, 2008, 13), (2008, 2014, 13), (2014, 2022, 13),
            (2022, 2024, 3)],  # SD after sanctions
    'TUR': [(1994, 2004, 7), (2004, 2013, 10), (2013, 2024, 8)],  # B+
    'ZAF': [(1994, 2000, 11), (2000, 2012, 13), (2012, 2017, 13),
            (2017, 2024, 10)],  # BBB → BB
    'IDN': [(1994, 1997, 13), (1997, 1999, 3), (1999, 2006, 7),
            (2006, 2012, 10), (2012, 2024, 13)],  # BBB
    'ARG': [(1993, 2001, 10), (2001, 2003, 0), (2003, 2006, 0),
            (2006, 2010, 7), (2010, 2014, 8), (2014, 2020, 3),
            (2020, 2024, 3)],  # serial defaulter
    'COL': [(1995, 2000, 13), (2000, 2011, 10), (2011, 2021, 13),
            (2021, 2024, 10)],  # BB+ → BBB → BB+
    'PER': [(1997, 2008, 10), (2008, 2024, 13)],  # BBB
    'PHL': [(1993, 2003, 10), (2003, 2013, 10), (2013, 2024, 13)],
    'THA': [(1990, 1997, 17), (1997, 2000, 10), (2000, 2004, 11),
            (2004, 2024, 13)],
    'MYS': [(1990, 1997, 17), (1997, 2000, 13), (2000, 2004, 17),
            (2004, 2024, 17)],  # A-
    'EGY': [(1997, 2011, 13), (2011, 2013, 8), (2013, 2024, 7)],  # B
    'NGA': [(2006, 2024, 7)],  # B range
    'PAK': [(1994, 1998, 8), (1998, 2004, 3), (2004, 2009, 7),
            (2009, 2024, 7)],  # B-/CCC range
    'VNM': [(2002, 2024, 10)],  # BB
    'POL': [(1995, 2007, 14), (2007, 2024, 17)],  # A-
    'HUN': [(1996, 2006, 17), (2006, 2012, 13), (2012, 2024, 13)],  # BBB
    'ROU': [(1996, 2005, 10), (2005, 2024, 13)],  # BBB-
    'GRC': [(1990, 2001, 17), (2001, 2009, 17), (2009, 2010, 14),
            (2010, 2012, 3), (2012, 2018, 7), (2018, 2024, 13)],
    'PRT': [(1990, 2005, 19), (2005, 2011, 17), (2011, 2012, 10),
            (2012, 2019, 13), (2019, 2024, 13)],  # BBB
    'CYP': [(2001, 2011, 17), (2011, 2013, 3), (2013, 2018, 10),
            (2018, 2024, 14)],  # BBB+
    'UKR': [(2001, 2008, 8), (2008, 2010, 3), (2010, 2014, 7),
            (2014, 2015, 3), (2015, 2022, 7), (2022, 2024, 3)],
    'KEN': [(2006, 2024, 7)],  # B/B+
    'GHA': [(2003, 2009, 8), (2009, 2014, 7), (2014, 2022, 7),
            (2022, 2024, 3)],  # CCC/SD
    'SEN': [(2000, 2024, 8)],  # B+
    'CIV': [(2015, 2024, 10)],  # BB-
    'MAR': [(2007, 2024, 13)],  # BBB-
    'TUN': [(2002, 2012, 13), (2012, 2024, 7)],  # B
    'JOR': [(2003, 2024, 10)],  # BB-
    'LKA': [(2005, 2020, 7), (2020, 2022, 3), (2022, 2024, 0)],  # SD
    'BGD': [(2010, 2024, 10)],  # BB-
    'KAZ': [(2002, 2024, 13)],  # BBB
    'AZE': [(2005, 2024, 13)],  # BBB-
    'BHR': [(2002, 2016, 13), (2016, 2024, 8)],  # B+
    'OMN': [(2007, 2020, 13), (2020, 2024, 10)],  # BB
    'CRI': [(2002, 2024, 10)],  # BB-/BB
    'PAN': [(2000, 2024, 13)],  # BBB range
    'URY': [(1995, 2002, 13), (2002, 2003, 7), (2003, 2012, 10),
            (2012, 2024, 13)],  # BBB
    'DOM': [(2001, 2003, 10), (2003, 2005, 3), (2005, 2024, 10)],  # BB-
    'GTM': [(2001, 2024, 10)],  # BB-/BB
    'HND': [(2004, 2024, 10)],  # BB-
    'SLV': [(2000, 2017, 10), (2017, 2020, 7), (2020, 2024, 7)],  # B-
    'JAM': [(2000, 2010, 7), (2010, 2013, 3), (2013, 2024, 10)],  # BB-
    'TTO': [(2000, 2016, 17), (2016, 2024, 13)],  # BBB
    'BWA': [(2001, 2024, 17)],  # A-
    'MUS': [(2004, 2024, 13)],  # BBB+
    'NAM': [(2005, 2014, 13), (2014, 2024, 10)],  # BB+
}


def build_extended_ratings_panel(years):
    """Build a country-year ratings panel from safe_assets plus EM histories.

    The safe_assets panel returned by ``build_ratings_panel(years)`` is
    extended with one row per country-year expanded from ``RATING_HISTORY_EM``.

    Duplicate (iso3, year) rows are resolved as follows:
      * Within the EM histories, consecutive periods share their boundary year;
        the higher ``rating_numeric`` is kept for that year.
      * Between the two sources, the EM row always wins. Previously the
        highest rating won regardless of source, and a missing rating could
        win because NaN sorts last.

    Args:
        years: Non-empty sequence of integer years. Only its minimum and
            maximum are used to clip the EM periods, so it need not be sorted.

    Returns:
        pandas.DataFrame sorted by (iso3, year) with at most one row per
        country-year. EM rows carry ``iso3``, ``year``, ``rating_numeric`` and
        ``safe_issuer``; any further columns coming from
        ``build_ratings_panel`` are NaN on EM-only rows.
    """
    # Start with safe_assets ratings (AA- and above histories)
    ratings = build_ratings_panel(years)

    first_year, last_year = min(years), max(years)

    # Expand every (start, end, rating) period into yearly rows, both endpoints
    # included, clipped to the requested window.
    em_records = []
    for iso3, periods in RATING_HISTORY_EM.items():
        for start, end, rating_num in periods:
            for yr in range(max(start, first_year), min(end, last_year) + 1):
                em_records.append({
                    'iso3': iso3,
                    'year': yr,
                    'rating_numeric': rating_num,
                    'safe_issuer': int(rating_num >= SAFE_THRESHOLD),
                })

    # Explicit columns keep the frame well-formed (and sortable) even when no
    # EM period overlaps the requested window.
    em_df = pd.DataFrame(
        em_records, columns=['iso3', 'year', 'rating_numeric', 'safe_issuer'])
    em_df = em_df.sort_values(['iso3', 'year', 'rating_numeric']).drop_duplicates(
        subset=['iso3', 'year'], keep='last')

    # Combine, letting EM override if duplicates. The temporary rank column
    # makes the source precedence explicit: sorting on it before the rating
    # and keeping the last row means an EM row beats a safe_assets row, while
    # ties inside one source are still broken by the rating as before.
    frames = [ratings.assign(_source_rank=0)]
    if not em_df.empty:
        frames.append(em_df.assign(_source_rank=1))
    combined = pd.concat(frames, ignore_index=True)
    combined = (
        combined
        .sort_values(['iso3', 'year', '_source_rank', 'rating_numeric'])
        .drop_duplicates(subset=['iso3', 'year'], keep='last')
        .drop(columns='_source_rank')
    )

    return combined


def _weighted_yearly_mean(panel, value_col, weight_col, out_name):
    """Return the per-year weighted mean of ``value_col`` as a Series named ``out_name``.

    Rows with a missing value or weight are ignored. The result is indexed by
    ``year``; a year whose weights sum to zero yields NaN instead of raising.
    """
    valid = panel.dropna(subset=[value_col, weight_col])
    year = valid['year']
    weighted_sum = (valid[value_col] * valid[weight_col]).groupby(year).sum()
    weight_total = valid[weight_col].groupby(year).sum()
    return (weighted_sum / weight_total.where(weight_total != 0)).rename(out_name)


def _add_benchmark_spread(panel, benchmark_iso3, benchmark_col, spread_col):
    """Attach one country's 10y yield by year and the spread of every row against it.

    The benchmark series is reduced to one row per year before merging so a
    duplicated benchmark row can never multiply the rows of ``panel``.
    """
    benchmark = (
        panel.loc[panel['iso3'] == benchmark_iso3, ['year', 'govt_bond_10y']]
        .drop_duplicates(subset='year')
        .rename(columns={'govt_bond_10y': benchmark_col})
    )
    panel = panel.merge(benchmark, on='year', how='left')
    panel[spread_col] = panel['govt_bond_10y'] - panel[benchmark_col]
    return panel


def main():
    """Assemble the sovereign-spreads panel and write it to disk.

    Steps, in order:
      1. Load ``full_panel.csv`` and drop years after the sample end.
      2. Build the extended ratings panel and left-merge it on (iso3, year).
      3. Construct spreads: 10y yield vs the GDP-weighted world average, vs the
         US, vs Germany, and the 3m rate vs its GDP-weighted world average.
      4. Left-merge fiscal variables from ``fiscal_panel.csv`` when that file
         exists (only variables not already present in the panel).
      5. Derive rating, fiscal, lag, income-group and demographic variables.
      6. Restrict to the 1990-2024 window and print coverage counts.
      7. Write ``summary_statistics.md`` to ``TABLES_DIR``.
      8. Write ``spread_panel.csv`` to ``PROCESSED_DIR``.

    Returns:
        pandas.DataFrame: the final panel, identical to what is written to
        ``spread_panel.csv``.

    Raises:
        FileNotFoundError: if ``full_panel.csv`` does not exist.
        KeyError: if the loaded panel lacks ``iso3``, ``year``,
            ``govt_bond_10y``, ``short_rate_3m`` or ``ngdp_usd``; every other
            input column is optional and guarded.
    """
    first_year, last_year = 1990, 2024
    # Numeric rating at or above which a sovereign is flagged investment grade.
    # NOTE(review): presumably BBB- on RATING_SCALE — confirm against that table.
    investment_grade_floor = 12

    print("=" * 70)
    print("PHASE 1: Data Assembly — Sovereign Spreads Panel")
    print("=" * 70)

    # ── [1] Load full_panel.csv ──
    print("\n[1] Loading full_panel.csv ...")
    full_path = MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv"
    df = pd.read_csv(full_path)
    # Years before first_year are kept for now so that lags computed below can
    # reach back past the start of the sample; they are dropped in step [6].
    df = df[df['year'] <= last_year].copy()
    print(f"  Full panel: {df['iso3'].nunique()} countries, {len(df):,} obs")

    # ── [2] Build extended ratings panel ──
    print("\n[2] Building extended ratings panel ...")
    years = list(range(first_year, last_year + 1))
    ratings = build_extended_ratings_panel(years)
    rated_countries = sorted(ratings['iso3'].unique())
    print(f"  Rated countries: {len(rated_countries)}")
    print(f"  Ratings obs: {len(ratings):,}")

    # Merge ratings into full panel
    df = df.merge(ratings[['iso3', 'year', 'rating_numeric', 'safe_issuer']],
                  on=['iso3', 'year'], how='left')

    # ── [3] Construct sovereign spread variables ──
    print("\n[3] Constructing spread variables ...")

    # (a) Sovereign spread: govt_bond_10y - GDP-weighted world average.
    # The average is taken over whichever countries report both a yield and
    # GDP in a given year, the country itself included.
    world_yield = _weighted_yearly_mean(
        df, 'govt_bond_10y', 'ngdp_usd', 'world_yield_10y')
    df = df.merge(world_yield.reset_index(), on='year', how='left')
    df['sovereign_spread'] = df['govt_bond_10y'] - df['world_yield_10y']
    n_spread = df['sovereign_spread'].notna().sum()
    print(f"  sovereign_spread (10y - world avg): {n_spread:,} obs, "
          f"{df.loc[df['sovereign_spread'].notna(), 'iso3'].nunique()} countries")

    # (b) US-referenced spread (standard for EM studies)
    df = _add_benchmark_spread(df, 'USA', 'us_yield_10y', 'spread_vs_us')
    n_us = df['spread_vs_us'].notna().sum()
    print(f"  spread_vs_us (10y - US 10y): {n_us:,} obs")

    # (c) Germany-referenced spread (for European analysis)
    df = _add_benchmark_spread(df, 'DEU', 'de_yield_10y', 'spread_vs_de')
    n_de = df['spread_vs_de'].notna().sum()
    print(f"  spread_vs_de (10y - DEU 10y): {n_de:,} obs")

    # (d) Short-rate spread
    world_short = _weighted_yearly_mean(
        df, 'short_rate_3m', 'ngdp_usd', 'world_short_3m')
    df = df.merge(world_short.reset_index(), on='year', how='left')
    df['short_spread'] = df['short_rate_3m'] - df['world_short_3m']
    n_short = df['short_spread'].notna().sum()
    print(f"  short_spread (3m - world avg): {n_short:,} obs")

    # ── [4] Load fiscal data ──
    print("\n[4] Loading fiscal data ...")
    fiscal_path = FISCAL_DIR / "data" / "processed" / "fiscal_panel.csv"
    if fiscal_path.exists():
        fisc = pd.read_csv(fiscal_path)
        fiscal_vars = ['govt_debt_gdp', 'govt_net_debt_gdp', 'primary_bal_gdp',
                       'structural_bal_gdp', 'govt_revenue_gdp', 'govt_expenditure_gdp',
                       'r_minus_g']
        # Only bring in variables the panel does not already have, so the merge
        # never produces _x/_y suffixed duplicates.
        existing = set(df.columns)
        new_fiscal_vars = [c for c in fiscal_vars
                           if c in fisc.columns and c not in existing]
        if new_fiscal_vars:
            fisc_merge = fisc[['iso3', 'year'] + new_fiscal_vars].drop_duplicates(
                subset=['iso3', 'year'])
            df = df.merge(fisc_merge, on=['iso3', 'year'], how='left')
            for v in new_fiscal_vars:
                n = df[v].notna().sum()
                print(f"  {v}: {n:,} non-null")
    else:
        print("  fiscal_panel.csv not found — skipping fiscal vars")

    # ── [5] Construct derived variables ──
    print("\n[5] Constructing derived variables ...")

    # Rating-related: 1.0 / 0.0 for rated rows, NaN where no rating exists.
    df['investment_grade'] = (
        df['rating_numeric'] >= investment_grade_floor).astype(float)
    df.loc[df['rating_numeric'].isna(), 'investment_grade'] = np.nan

    # Every shift()/diff() below is positional within a country, so it equals a
    # calendar-year lag only if each country has one row per year without gaps.
    # NOTE(review): confirm full_panel.csv is a gap-free annual panel.
    df = df.sort_values(['iso3', 'year'])

    # Rating change (lagged)
    df['rating_lag'] = df.groupby('iso3')['rating_numeric'].shift(1)
    df['rating_change'] = df['rating_numeric'] - df['rating_lag']
    # A missing change compares False, so unrated rows and the first rated year
    # get 0 here (unlike investment_grade, which is NaN when unrated).
    df['downgrade_any'] = (df['rating_change'] < 0).astype(int)

    # Fiscal stress
    if 'govt_debt_gdp' in df.columns:
        df['debt_lag5'] = df.groupby('iso3')['govt_debt_gdp'].shift(5)
        df['debt_change_5y'] = df['govt_debt_gdp'] - df['debt_lag5']
    if 'govt_expenditure_gdp' in df.columns and 'govt_revenue_gdp' in df.columns:
        df['exp_rev_gap'] = df['govt_expenditure_gdp'] - df['govt_revenue_gdp']

    # Lagged spread for dynamic models
    df['sovereign_spread_lag'] = df.groupby('iso3')['sovereign_spread'].shift(1)
    df['spread_vs_us_lag'] = df.groupby('iso3')['spread_vs_us'].shift(1)

    # Log spread (for EM with positive spreads). Non-positive spreads are masked
    # to NaN before taking the log, so the log is never evaluated on them.
    positive_us_spread = df['spread_vs_us'].where(df['spread_vs_us'] > 0)
    df['log_spread_vs_us'] = np.log(positive_us_spread)

    # OECD dummy
    df['oecd'] = df['iso3'].isin(OECD_38).astype(int)

    # Income groups: terciles of GDP per capita computed separately for each year.
    if 'gdp_pc_ppp' in df.columns:
        def safe_qcut(x):
            # Years where terciles cannot be formed fall back to all-NaN.
            try:
                return pd.qcut(x, 3, labels=['low', 'mid', 'high'], duplicates='drop')
            except (ValueError, IndexError):
                return pd.Series(np.nan, index=x.index)
        df['income_group'] = df.groupby('year')['gdp_pc_ppp'].transform(safe_qcut)

    # Predetermined demographics: old-age dependency 10 and 20 rows ahead.
    # NOTE(review): the panel was already cut at last_year in step [1], so these
    # leads are NaN for the final 10 / 20 years of every country even if
    # full_panel.csv contains later (projected) years — confirm this is intended.
    if 'old_dep' in df.columns:
        df['oadr_plus10'] = df.groupby('iso3')['old_dep'].shift(-10)
        df['oadr_plus20'] = df.groupby('iso3')['old_dep'].shift(-20)

    # The three passes below stay separate so the output column order is
    # interactions first, then 5-year lags, then first differences.
    z_cols = [zv for zv in ('Z_1', 'Z_2', 'Z_3') if zv in df.columns]

    # Demographic × fiscal interactions
    for zv in z_cols:
        if 'govt_debt_gdp' in df.columns:
            df[f'{zv}_x_debt'] = df[zv] * df['govt_debt_gdp']
        if 'fiscal_bal_gdp' in df.columns:
            df[f'{zv}_x_fiscal'] = df[zv] * df['fiscal_bal_gdp']

    # 5-year lagged demographics
    for zv in z_cols:
        df[f'{zv}_lag5'] = df.groupby('iso3')[zv].shift(5)

    # First-differenced demographics
    for zv in z_cols:
        df[f'd_{zv}'] = df.groupby('iso3')[zv].diff()

    # ── [6] Restrict to the sample window ──
    # This is a fixed year window, not a filter on spread availability: rows
    # without any spread stay in the panel.
    print("\n[6] Panel summary ...")
    df = df[(df['year'] >= first_year) & (df['year'] <= last_year)].copy()

    n_total = len(df)
    n_countries = df['iso3'].nunique()
    n_rated = df.loc[df['rating_numeric'].notna(), 'iso3'].nunique()
    n_with_spread = df.loc[df['sovereign_spread'].notna(), 'iso3'].nunique()
    n_with_us_spread = df.loc[df['spread_vs_us'].notna(), 'iso3'].nunique()

    print(f"  Total panel: {n_countries} countries, {n_total:,} obs")
    print(f"  Countries with ratings: {n_rated}")
    print(f"  Countries with 10y spread: {n_with_spread}")
    print(f"  Countries with US spread: {n_with_us_spread}")

    # ── [7] Summary statistics table ──
    print("\n[7] Building summary statistics ...")
    key_vars = ['sovereign_spread', 'spread_vs_us', 'spread_vs_de',
                'rating_numeric', 'govt_bond_10y', 'short_rate_3m',
                'Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep',
                'fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth',
                'kaopen', 'govt_debt_gdp', 'trade_openness']
    key_vars = [v for v in key_vars if v in df.columns]

    md = ["# Summary Statistics — Sovereign Spreads Panel\n"]
    md.append("| Variable | N | Mean | SD | Min | Max |")
    md.append("|---|---|---|---|---|---|")
    for v in key_vars:
        s = df[v].dropna()
        if len(s) > 0:
            md.append(f"| {v} | {len(s):,} | {s.mean():.3f} | {s.std():.3f} "
                      f"| {s.min():.3f} | {s.max():.3f} |")
    md.append(f"\n*Panel: {n_countries} countries, {df['year'].min()}-{df['year'].max()}*")

    out = TABLES_DIR / "summary_statistics.md"
    # The table title contains a non-ASCII dash, so the encoding is pinned
    # rather than left to the platform locale.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out}")

    # ── [8] Save ──
    print("\n[8] Saving spread_panel.csv ...")
    df.to_csv(PROCESSED_DIR / "spread_panel.csv", index=False)
    print(f"  Saved: {PROCESSED_DIR / 'spread_panel.csv'}")
    print(f"  Shape: {df.shape[0]:,} obs x {df.shape[1]} columns")

    print("\n" + "=" * 70)
    print("Phase 1 complete.")
    print("=" * 70)

    return df


if __name__ == "__main__":
    # Run the full pipeline only when executed as a script. The returned panel
    # is bound to a module-level `df`, so it stays available for inspection in
    # an interactive session (e.g. `python -i`); nothing else here reads it.
    df = main()
